; docformat = 'rst'
;
;+
;
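; :Purpose:
;   Iterative (damped) quasi-Newton inversion: adjusts the state vector so that
;   the model fits the measurements, subject to prior and growth constraints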
;   
; :Inputs:
;   parameters: hash table containing inversion parameters
;   See a2i_run.pro
;
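; :Keywords:
;   emissions_bias: (optional, input) standard deviation of bias to randomly apply to emissions
;   iteration: (optional, input) index used in the output file name output_{iteration}.sav
;     when emissions_bias is set (see Outputs)
;   nProcesses: (optional, input) number of processes; if greater than 1, the sensitivity
;     is computed in parallel using multiple processes
;   mc: (optional, input) Monte-Carlo iteration, randomly perturb measurements and priors
;   nIterations: (optional, input) number of quasi-Newton iterations (default 8)
;   alpha: (optional, input) step-size factor applied to each state-vector update (default 0.8)
;   linear: (optional, input) accepted but currently unused
;   emissions_no_modify: (optional, input) passed through to a2i_emissions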
;   
; :Outputs:
;   
;   {AIF_DIRECTORY}/{CASE_NAME}/output/output.sav
;   
;   or, if emissions_bias is set:
;   {AIF_DIRECTORY}/{CASE_NAME}/output/output_{iteration}.sav
;   (iteration zero-padded to three digits)
;   
; :Example::
;   a2i_invert, parameters
;   
; :History:
; 	Written by: Matt Rigby, MIT, Sep 6, 2012
;
;-
pro a2i_invert, parameters, emissions_bias=emissions_bias, iteration=iteration, nProcesses=nProcesses, mc=mc, $
  nIterations=nIterations, alpha=alpha, linear=linear, emissions_no_modify=emissions_no_modify

  compile_opt idl2, hidden

  ; Outline:
  ;   1. Load measurements, emissions and state vector; build the model.
  ;   2. Run the reference (prior) model.
  ;   3. Iterate: compute sensitivity H, form the cost-function gradient (dSdx)
  ;      and Hessian approximation (d2Sdx2), step x_i by -alpha*P##dSdx,
  ;      re-run the model and evaluate the cost function J.
  ;   4. Compute lifetimes, compress H / H_lifetime / P and save the results.
  ; NOTE: the LINEAR keyword is accepted but not used by this routine.

  if n_elements(parameters) eq 0 then begin
    print, 'Parameters needed'
    return
  endif

  ; Number of quasi-Newton iterations (default 8)
  if n_elements(nIterations) eq 0 then nIt=8 else nIt=nIterations

  ; nProcesses is optional: default to serial sensitivity calculation. A local
  ; copy is used so an undefined caller variable is not modified.
  if n_elements(nProcesses) eq 0 then nProc=1 else nProc=nProcesses

  if n_elements(emissions_bias) gt 0 then print, '!BIASED EMISSIONS!', iteration

  ;SET UP MODEL AND STATE VECTOR
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

  y=a2i_measurements(parameters)

  if ~keyword_set(emissions_no_modify) then begin
    emissions_no_modify=0
  endif
  emissions=a2i_emissions(parameters, y, emissions_bias=emissions_bias, emissions_no_modify=emissions_no_modify)

  if emissions eq -1 then return
  x=a2i_state_define(parameters, y, emissions)

;  ;MAXIMUM ERROR
;  error=x['ERROR']
;  wh=where(error gt 10., count)
;  if count gt 0 then begin
;    error[wh]=10.
;  endif
;  x['ERROR']=error

  model=a2i_model(parameters, y, emissions)
  if keyword_set(mc) then begin
    a2i_mc, parameters, y, x, model, mc=mc
;    model=a2i_model_scale(parameters, y, model, x)
  endif

  ; model_scaled is replaced by the scaled model after each iteration;
  ; model_ap keeps the unmodified a priori model for the output file.
  model_scaled=mr_hash_copy(model)
  model_ap=mr_hash_copy(model)

  print, string(n_elements(y['WH_MEASURE'])) + ' measurements'
  print, string(n_elements(x['X'])) + ' state vector elements'

  ;Define lifetime arrays
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ; One entry per (pollutant, time index) pair, pollutant-major ordering
  nLife=parameters['NPOLLUTANTS']*parameters['TIMESIZE']
  lifetime=hash()
  lifetime['POLLUTANT']=intarr(nLife)
  lifetime['TI']=intarr(nLife)
  for pi=0, parameters['NPOLLUTANTS']-1 do begin
    i0=pi*parameters['TIMESIZE']
    i1=(pi+1)*parameters['TIMESIZE']-1
    lifetime['POLLUTANT', i0 : i1]=pi
    lifetime['TI', i0 : i1]=indgen(parameters['TIMESIZE'])
  endfor


  ;REFERENCE RUN
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  y['MODEL']=a2i_model_run(parameters, y, model, lifetime, restart=restart, out_lifetime=lifetime_ref)
  lifetime['MODEL']=temporary(lifetime_ref)

  lifetime['AP']=lifetime['MODEL']
  y['AP']=y['MODEL']

  ;INVERSION MATRICES
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ; The measurement-error inverse covariance R1 is diagonal, so it is never
  ; formed as a dense matrix: R1_column_scale applies it by element-wise
  ; multiplication instead (see HTR1 below).

  ; Prior inverse covariance (only for state elements selected by x['WH'];
  ; x['WH', 0] lt 0 presumably means "none" - confirm against a2i_state_define)
  if x['WH', 0] ge 0 then begin
    P1=diag_matrix(x['ERROR', x['WH']]^(-2.))
    P1_row_scale=rebin(x['ERROR', x['WH']]^(-2), [n_elements(x['WH'])])
  endif
  x_ap=x['AP']
  y_i=y['Y', Y['WH_MEASURE']]
  R1_column_scale=rebin(y['ERROR', y['WH_MEASURE']]^(-2), [n_elements(y['WH_MEASURE']), x['NSTATE']])

  ;Growth matrices
  if nIt gt 0 then begin
    a2i_growth, parameters, x, emissions, y
  endif

  ;BEGIN INVERSION
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  if n_elements(alpha) eq 0 then alpha=0.8

  ; Cost function value recorded after each iteration
  if nIt gt 0 then J=fltarr(nIt)

  print, strcompress('Beginning ' + string(nIt) + ' iterations...')

  ; x_i is the departure of the state vector from the a priori (x = x_ap + x_i)
  X_i=dblarr(X['NSTATE'])

  end_early=0
  it=0
  yRec=!null
  xRec=!null

  if nIt gt 0 then begin
  repeat begin

    print, 'Calculating sensitivity...' & t0=systime(/seconds)

    if nProc gt 1 then begin
      ; Worker processes are created on the first iteration and destroyed
      ; by the callee on the last one.
      H=a2i_sensitivity_parallel(parameters, y, model, x, lifetime, restart, H_lifetime=H_lifetime, $
        nProcesses=nProc, create_processes=(it eq 0), destroy_processes=(it eq (nIt-1)), oBridge=oBridge)
    endif else begin
      H=a2i_sensitivity(parameters, y, model, x, lifetime, restart, H_lifetime=H_lifetime)
    endelse

    print, '... done. ' + string((systime(/seconds) - t0)/60., format='(f4.1)') + ' min' & $
      t0=systime(/seconds)

    ;Cost function derivatives:
    ;d2Sdx2=(transpose(H)##R1##H + transpose(D)##S1##D + P1)
    ;dSdx=(transpose(H)##R1##(y_ref - y) + transpose(D)##S1##(D##x - x_growth) + P1##(x - x_ap))

    if parameters['FIRN'] then begin

      x_i=a2i_firn(parameters, x, x_i, y, y_i, H, P1, P1_row_scale, R1_column_scale, alpha, P)

    endif else begin

      Hr=H[*, Y['WH_MEASURE']]
      ;Efficient multiply diagonal matrix: HTR1=transpose(Hr)##R1
      HTR1=transpose(Hr)*R1_column_scale

      d2Sdx2=HTR1##Hr ;+P1
      dSdx=HTR1##(y['MODEL', Y['WH_MEASURE']] - y_i); + P1_row_scale*x_i

      ; Prior term
      if x['WH', 0] ge 0 then begin
        d2Sdx2[x['WH2D']]+=P1
        dSdx[x['WH']]+=P1_row_scale*x_i[x['WH']]
      endif

      ; Growth constraint terms
      if x.haskey('GROWTH_WH') then begin
        if x['GROWTH_WH', 0] ge 0 then begin
          d2sdx2[x['GROWTH_WH2D']]+=x['DTS1D']
          dSdx[x['GROWTH_WH']]+=x['DTS1']##(x['D']##x_i[x['GROWTH_WH']] - x['GROWTH'])
        endif
      endif

      if x.haskey('GROWTH_WH_L') then begin
        if x['GROWTH_WH_L', 0] ge 0 then begin
          d2sdx2[x['GROWTH_WH2D_L']]+=x['DTS1D_L']
          dSdx[x['GROWTH_WH_L']]+=x['DTS1_L']##(x['D_L']##x_i[x['GROWTH_WH_L']] - x['GROWTH_L'])
        endif
      endif

      ; Damped Newton step using the inverted Hessian approximation
      P=invert(temporary(d2sdx2))
      x_i-=alpha*(P##dSdx)

    endelse


    print, 'State vector updated. ' + string((systime(/seconds) - t0)/60., format='(f4.1)') + ' min'

    ; Abort on non-finite or exploding state vector
    if total(finite(x_i) eq 0) gt 0. or total(abs(x_i) gt 1.e10) gt 0. then begin
      print, 'AIF_INVERT: infinite x'
      if nProc gt 1 then begin
        for proI=0, nProc-1 do begin
          obj_destroy, oBridge[proI]
        endfor
      endif
      return
    endif

    X['X']=X['AP'] + x_i

    ;UPDATED RUN
    ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
    model_scaled=a2i_model_scale(parameters, y, model, x)
    y['MODEL']=a2i_model_run(parameters, y, model_scaled, lifetime, $
      restart=restart, out_lifetime=lifetime_ref)
    lifetime['MODEL']=temporary(lifetime_ref)

    ; Cost function: measurement misfit + prior + growth-constraint terms
    residual=(y['MODEL', y['WH_MEASURE']] - y['Y', y['WH_MEASURE']])
    J[it]=total(residual^2/y['ERROR', y['WH_MEASURE']]^2)
    if x['WH', 0] ge 0 then begin
      J[it]+=total(( (x['X', x['WH']] - x['AP', x['WH']])/x['ERROR', x['WH']] )^2)
    endif
    if x.haskey('GROWTH_WH') then begin
      if x['GROWTH_WH', 0] ge 0 then begin
        J[it]+=total(( (x['D']##x['X', x['GROWTH_WH']])/x['GROWTH_ERROR'] )^2)
      endif
    endif
    if x.haskey('GROWTH_WH_L') then begin
      if x['GROWTH_WH_L', 0] ge 0 then begin
        J[it]+=total(( (x['D_L']##x['X', x['GROWTH_WH_L']])/x['GROWTH_ERROR_L'] )^2)
      endif
    endif

    ; Divergence / convergence checks. NOTE(review): these only start from the
    ; third iteration (it gt 1) although J[it-1] already exists at it eq 1 -
    ; presumably deliberate to tolerate a large first step; confirm.
    if it gt 1 then begin
      ;Abort if J grows!
      if J[it] gt 1.1*J[it-1] then begin
        print, 'AIF_INVERT: Cost function growing!'
        if nProc gt 1 then begin
          for proI=0, nProc-1 do begin
            obj_destroy, oBridge[proI]
          endfor
        endif
        return
      endif

      ; Converged: relative change in J (normalised by J[0]) below 1e-4
      if abs((J[it] - j[it-1])/J[0]) lt 0.0001 then begin
        end_early=1
        if nProc gt 1 then begin
          for proI=0, nProc-1 do begin
            obj_destroy, oBridge[proI]
          endfor
        endif
        print, 'Early exit'
      endif
    endif

    ; Record model output and state vector from every iteration
    yRec=[[[yRec]], [y['MODEL']]]
    xRec=[[[xRec]], [x['X']]]

    print, '... done iteration ' + string(it+1) + ' of ' + string(nIt)

    it++

  endrep until (end_early) or (it ge nIt)
  endif

  model=mr_hash_copy(model_scaled)

  a2i_lifetime, parameters, lifetime, model

  ; Trim unused cost-function slots after an early (converged) exit
  if end_early then J=J[0:it-1]


  ;Reduce sizes for storage
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ; Store H, H_lifetime and P sparsely: values above 1e-7 in magnitude plus
  ; their linear indices.

  if nIt gt 0 then begin

    x['REC']=transpose(xRec)
    y['REC']=transpose(yRec)

    H=float(H)
    H_lifetime=float(H_lifetime)
    P=float(P)

    Hwh=where(abs(H) gt 1.e-7)
  ;  Hwh=where(H ne 0.)
    Hsp=H[Hwh]

    H_lifetimeWh=where(abs(H_lifetime) gt 1.e-7)
  ;  H_lifetimeWh=where(H_lifetime ne 0.)
    ; Values must come from H_lifetime (indices were computed on H_lifetime)
    H_lifetimeSp=H_lifetime[H_lifetimeWh]

    Pwh=where(abs(P) gt 1.e-7)
  ;  Pwh=where(P ne 0.)
    Psp=P[Pwh]
  endif

  ;Put firn measurements in y matrix
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

  if parameters['FIRN'] then begin

    a2i_firn_effective_date, parameters, y

    for fi=0, n_elements(y['FIRN'])-1 do begin
      wh=where(y['TI'] eq fix(y['FIRN_TI', fi]) and $
        y['BOX'] eq y['FIRN_BOX', fi] and $
        y['POLLUTANT'] eq y['FIRN_POLLUTANT', fi], nMatch)
      ; WHERE returns -1 when nothing matches; subscripting with -1 would
      ; silently overwrite the LAST element, so skip unmatched firn points.
      if nMatch eq 0 then begin
        print, 'AIF_INVERT: no model point matches firn measurement', fi
        continue
      endif
      y['Y', wh]=y['FIRN', fi]
      y['ERROR', wh]=y['FIRN_ERROR', fi]
    endfor
  endif


  ;Save
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

  if n_elements(emissions_bias) gt 0 then begin
    save, filename=a2i_filestr(/input, parameters['CASE_NAME'] + $
      '/output/output_' + string(iteration, format='(I03)') + '.sav'), $
      parameters, model, emissions, y, x, lifetime, J, alpha
  endif else begin
;    if nIt gt 0 and y['NMEASUREMENTS'] lt 10000L then begin
;      y['MODEL_ERROR']=sqrt(diag_matrix(H##P##transpose(H)))
;    endif

    save, filename=a2i_filestr(/input, parameters['CASE_NAME'] + '/output/output.sav'), $
      parameters, model, model_ap, y, x, lifetime, Hsp, Hwh, Psp, Pwh, $
      H_lifetimeSp, H_lifetimeWh, J, alpha
  endelse

end